from abc import ABC, abstractmethod
from geartrain.utils import DEFAULT_CFG
class BaseDetPredictor(ABC):
    """Abstract base class for detection predictors.

    Subclasses must implement ``_load_model`` (return the underlying model
    object) and ``predict`` (run inference on a single image). The default
    ``preprocess_image`` / ``postprocess_prediction`` / ``visualize_prediction``
    hooks are identity placeholders meant to be overridden.
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None) -> None:
        """Build the predictor from a configuration plus optional overrides.

        Parameters:
        cfg: Base configuration. Either a mapping or an attribute namespace
            (any object exposing ``__dict__``); ``None`` means "no base config".
        overrides (dict | None): Keys that replace entries from ``cfg``.
        _callbacks: Optional callbacks object, stored unchanged on
            ``self.callbacks`` for subclasses to use.

        Raises:
        TypeError: If ``cfg`` is neither a mapping nor an attribute namespace.
        """
        super().__init__()
        # Merged configuration; overrides win over the base cfg.
        self.args = self._merge_cfg(cfg, overrides)
        self.callbacks = _callbacks

        # NOTE(review): the previous code referenced these names without ever
        # defining them (NameError on construction). They are now read from the
        # merged config; "model" is accepted as a fallback key for the weights
        # path -- confirm the key names used by geartrain's DEFAULT_CFG.
        self.model_path = self.args.get("model_path", self.args.get("model"))
        self.device = self.args.get("device")
        self.conf = self.args.get("conf")
        self.iou = self.args.get("iou")
        self.model = self._load_model()

    @staticmethod
    def _merge_cfg(cfg, overrides=None) -> dict:
        """Return a new dict of ``cfg`` entries updated with ``overrides``.

        ``cfg`` itself is never mutated, so a shared module-level default
        config stays untouched.
        """
        from collections.abc import Mapping

        if cfg is None:
            merged = {}
        elif isinstance(cfg, Mapping):
            merged = dict(cfg)
        elif hasattr(cfg, "__dict__"):
            merged = dict(vars(cfg))
        else:
            raise TypeError(
                f"cfg must be a mapping or attribute namespace, got {type(cfg).__name__}"
            )
        if overrides:
            merged.update(overrides)
        return merged

    @abstractmethod
    def _load_model(self):
        """Load and return the underlying prediction model.

        Called once from ``__init__`` after ``model_path``, ``device``,
        ``conf`` and ``iou`` have been set; the result is stored in
        ``self.model``.
        """

    @abstractmethod
    def predict(self, image):
        """Run prediction on a single image.

        Parameters:
        image (numpy.ndarray): Input image.

        Returns:
        list: Predicted results. NOTE(review): the original template described
        each entry as a pose ``[x, y, conf]``; confirm the format detection
        subclasses are expected to return.
        """

    def preprocess_image(self, image):
        """Prepare an input image for the model.

        The base implementation is an identity placeholder; override it to
        add resizing, normalisation, etc.

        Parameters:
        image (numpy.ndarray): Input image.

        Returns:
        numpy.ndarray: Preprocessed image (the input, unchanged, by default).
        """
        return image

    def postprocess_prediction(self, prediction):
        """Convert raw model output into final predictions.

        The base implementation is an identity placeholder; override it to
        add filtering, coordinate transforms, etc.

        Parameters:
        prediction: Raw prediction output from the model.

        Returns:
        Processed predictions (the input, unchanged, by default).
        """
        return prediction

    def visualize_prediction(self, image, prediction):
        """Draw predictions onto an image.

        The base implementation is a placeholder that returns ``image``
        without drawing anything.

        Parameters:
        image (numpy.ndarray): Input image.
        prediction (list): Processed predictions.

        Returns:
        numpy.ndarray: Image with predictions drawn (unchanged by default).
        """
        return image

    def process_image(self, image):
        """Run the preprocess -> model -> postprocess pipeline on one image.

        Note that this calls ``self.model.predict`` directly rather than the
        subclass's ``self.predict``.

        Parameters:
        image (numpy.ndarray): Input image.

        Returns:
        Whatever ``postprocess_prediction`` returns for the raw model output.
        """
        preprocessed_image = self.preprocess_image(image)
        raw_prediction = self.model.predict(preprocessed_image)
        return self.postprocess_prediction(raw_prediction)

    def warmup(self, imgsz=(1, 3, 640, 640)):
        """Run one prediction on a random uint8 image to warm the model up.

        Parameters:
        imgsz (tuple): Input size in ``(..., C, H, W)`` order; only the last
            two (and, if present, third-to-last) entries are used. The image
            fed to ``predict`` is laid out ``(H, W, C)``. Defaults to a
            640x640 3-channel image.
        """
        import numpy as np

        height, width = int(imgsz[-2]), int(imgsz[-1])
        channels = int(imgsz[-3]) if len(imgsz) >= 3 else 3
        random_image = np.random.randint(
            0, 256, (height, width, channels), dtype=np.uint8
        )
        self.predict(random_image)
        print("warmup success ✅")
        
        
